fix: jax reducers returning incorrect output values or lengths #3464

Draft
wants to merge 17 commits into
base: main
from

Conversation

ikrommyd
Collaborator

@ikrommyd ikrommyd commented Apr 14, 2025

This needs more work, and I'd appreciate any help, @ianna @pfackeldey.
I'm adding a parametrized test that covers all the reducers.

Definitely needs #3457

This PR should fix #3456, #3462, #3463 and #3465

@ikrommyd
Collaborator Author

I'm currently skipping the argmin and argmax tests because of #3463, but that should change.

@ikrommyd
Collaborator Author

    # See issue https://github.com/google/jax/issues/9296
    result = jax.numpy.exp(
        jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
    )

So we can't take the product of an array with negative numbers? The logarithm will just turn the output into NaN.
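
To make the concern concrete, here is a minimal sketch, not the code from this PR and not necessarily what any later commit does. It reproduces the NaN from the log/exp trick and shows one possible workaround: sum the logs of the absolute values, count negatives and zeros per segment separately, and restore the sign at the end. The function name `segment_prod_via_logs` and the sample data are made up for illustration.

```python
import jax
import jax.numpy as jnp


def segment_prod_via_logs(data, segment_ids, num_segments):
    """Hypothetical helper: segment-wise product that tolerates negatives and zeros."""
    is_zero = data == 0
    is_negative = data < 0
    # Replace zeros by 1 so log() stays finite; zeros are handled separately below.
    magnitude = jnp.where(is_zero, 1.0, jnp.abs(data))
    log_sum = jax.ops.segment_sum(
        jnp.log(magnitude), segment_ids, num_segments=num_segments
    )
    # Count negative and zero entries per segment.
    n_negative = jax.ops.segment_sum(
        is_negative.astype(jnp.int32), segment_ids, num_segments=num_segments
    )
    n_zero = jax.ops.segment_sum(
        is_zero.astype(jnp.int32), segment_ids, num_segments=num_segments
    )
    # An odd number of negative factors flips the sign of the product.
    sign = jnp.where(n_negative % 2 == 0, 1.0, -1.0)
    # Any zero factor forces the product of that segment to zero.
    return jnp.where(n_zero > 0, 0.0, sign * jnp.exp(log_sum))


data = jnp.array([2.0, -3.0, 4.0, -1.0, -5.0, 0.0, 7.0])
parents = jnp.array([0, 0, 0, 1, 1, 2, 2])

# The plain log/exp trick: log of a negative number is NaN.
naive = jnp.exp(jax.ops.segment_sum(jnp.log(data), parents, num_segments=3))

# The sign-aware variant recovers the products up to floating-point rounding.
fixed = segment_prod_via_logs(data, parents, num_segments=3)
```

The trade-off is two extra `segment_sum` calls per reduction, and the result is still only as accurate as the log/exp round trip allows.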

@ikrommyd
Collaborator Author

    # See issue https://github.com/google/jax/issues/9296
    result = jax.numpy.exp(
        jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
    )

So we can't take the product of an array with negative numbers? The logarithm will just turn the output into NaN.

I may have fixed that in 9874810. I still haven't made ak.any behave like the CPU case when segments contain zeros, though.

@ianna
Collaborator

ianna commented Apr 14, 2025

@ikrommyd - impressive! Only 5 tests fail: 3 for ak.any, 1 for ak.count, and 1 for ak.sum. It looks like the last one is failing with a boolean dtype. Perhaps the return value should be an int?

Collaborator

@ianna ianna left a comment

Thanks for looking into it!

@@ -261,7 +328,7 @@ def apply(
         if array.dtype.kind == "M":
             raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}")

Collaborator

@ikrommyd - if we want to allow the sum of booleans, I think we should cast the data to integers here:

    if array.dtype == np.bool_:
        data = array.data.astype(jax.numpy.int32)
    else:
        data = array.data
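
As a standalone illustration of the suggestion above (a sketch with made-up sample data, not the code in this PR), casting booleans to integers before the segment sum yields a per-segment count of `True` values with an integer dtype:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Made-up input: a flat boolean buffer and the segment each element belongs to.
flags = jnp.array([True, False, True, True, False])
parents = jnp.array([0, 0, 0, 1, 1])

# Cast booleans to integers first, as in the review suggestion.
if flags.dtype == np.bool_:
    data = flags.astype(jnp.int32)
else:
    data = flags

# Each segment's sum is now the number of True values in it.
counts = jax.ops.segment_sum(data, parents, num_segments=2)
```

Whether `int32` or a wider integer type is the right target depends on what the CPU backend returns for the same input, which this sketch does not check.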

@ikrommyd
Collaborator Author

ikrommyd commented Apr 14, 2025

So I made CI pass, but a few things definitely remain to be done:

  1. The code needs refactoring. I don't like how it looks at all; it's very hacky.
  2. The reducer tests should definitely cover more input array cases.
  3. We need to make sure I'm not breaking something that should work but is untested because the JAX backend has fewer tests.

Collaborator

@ianna ianna left a comment

@ikrommyd - just some minor comments. I agree, the JAX backend needs to be thoroughly tested. Perhaps our fellow could take over? @pfackeldey - when do we discuss his project? Thanks.

@@ -68,7 +131,7 @@ def segment_argmin(data, segment_ids):
 class ArgMin(JAXReducer):
     name: Final = "argmin"
     needs_position: Final = True
-    preferred_dtype: Final = np.int64
+    preferred_dtype: Final = np.float64
Collaborator

argmin returns the index (i.e. position) of the minimum value. Indices are always integers, not floating-point numbers.
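
To illustrate the point, here is a rough sketch of a segment-wise argmin built from `jax.ops.segment_min`. It is not the `segment_argmin` from this PR; the name `first_min_position_per_segment` and the data are hypothetical, and it returns positions in the flat buffer rather than positions within each list. The result has an integer dtype regardless of the dtype of the data.

```python
import jax
import jax.numpy as jnp


def first_min_position_per_segment(data, segment_ids, num_segments):
    """Hypothetical helper: flat position of the first minimum in each segment."""
    n = data.shape[0]
    # Minimum value of each segment.
    mins = jax.ops.segment_min(data, segment_ids, num_segments=num_segments)
    positions = jnp.arange(n)
    # Keep the position where an element equals its segment minimum;
    # everything else gets the out-of-range sentinel n.
    is_min = data == mins[segment_ids]
    candidates = jnp.where(is_min, positions, n)
    # The smallest surviving position per segment is the first minimum.
    return jax.ops.segment_min(candidates, segment_ids, num_segments=num_segments)


data = jnp.array([3.0, 1.0, 1.0, 5.0, 4.0])
parents = jnp.array([0, 0, 0, 1, 1])
result = first_min_position_per_segment(data, parents, num_segments=2)
```

The sketch assumes every segment is non-empty; an empty segment would need its own handling.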

@@ -125,7 +191,7 @@ def segment_argmax(data, segment_ids):
 class ArgMax(JAXReducer):
     name: Final = "argmax"
     needs_position: Final = True
-    preferred_dtype: Final = np.int64
+    preferred_dtype: Final = np.float64
Collaborator

The same argument applies here: indices are always integers, not floating-point numbers.
